#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""SDK Fn Harness entry point."""

# pytype: skip-file

import importlib
import json
import logging
import os
import re
import signal
import sys
import time
import traceback

from google.protobuf import text_format

from apache_beam.internal import pickler
from apache_beam.io import filesystems
from apache_beam.options.pipeline_options import DebugOptions
from apache_beam.options.pipeline_options import GoogleCloudOptions
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import ProfilingOptions
from apache_beam.options.pipeline_options import SetupOptions
from apache_beam.options.pipeline_options import WorkerOptions
from apache_beam.options.value_provider import RuntimeValueProvider
from apache_beam.portability.api import endpoints_pb2
from apache_beam.runners.internal import names
from apache_beam.runners.worker.data_sampler import DataSampler
from apache_beam.runners.worker.log_handler import FnApiLogRecordHandler
from apache_beam.runners.worker.sdk_worker import SdkHarness
from apache_beam.utils import profiler

_LOGGER = logging.getLogger(__name__)
_ENABLE_GOOGLE_CLOUD_PROFILER = 'enable_google_cloud_profiler'
_FN_LOG_HANDLER = None


def _import_beam_plugins(plugins):
  for plugin in plugins:
    try:
      importlib.import_module(plugin)
      _LOGGER.debug('Imported beam-plugin %s', plugin)
    except ImportError:
      try:
        _LOGGER.debug((
            "Looks like %s is not a module. "
            "Trying to import it assuming it's a class"),
                      plugin)
        module, _ = plugin.rsplit('.', 1)
        importlib.import_module(module)
        _LOGGER.debug('Imported %s for beam-plugin %s', module, plugin)
      except ImportError as exc:
        _LOGGER.warning('Failed to import beam-plugin %s', plugin, exc_info=exc)


def create_harness(environment, dry_run=False):
  """Creates SDK Fn Harness."""

  deferred_exception = None
  if 'LOGGING_API_SERVICE_DESCRIPTOR' in environment:
    try:
      logging_service_descriptor = endpoints_pb2.ApiServiceDescriptor()
      text_format.Merge(
          environment['LOGGING_API_SERVICE_DESCRIPTOR'],
          logging_service_descriptor)

      # Send all logs to the runner.
      fn_log_handler = FnApiLogRecordHandler(logging_service_descriptor)
      logging.getLogger().addHandler(fn_log_handler)
      _LOGGER.info('Logging handler created.')
    except Exception:
      _LOGGER.error(
          "Failed to set up logging handler, continuing without.",
          exc_info=True)
      fn_log_handler = None
  else:
    fn_log_handler = None

  pipeline_options_dict = _load_pipeline_options(
      environment.get('PIPELINE_OPTIONS'))
  default_log_level = _get_log_level_from_options_dict(pipeline_options_dict)
  logging.getLogger().setLevel(default_log_level)
  _set_log_level_overrides(pipeline_options_dict)

  # These are used for dataflow templates.
  RuntimeValueProvider.set_runtime_options(pipeline_options_dict)
  sdk_pipeline_options = PipelineOptions.from_dictionary(pipeline_options_dict)
  filesystems.FileSystems.set_options(sdk_pipeline_options)
  pickle_library = sdk_pipeline_options.view_as(SetupOptions).pickle_library
  pickler.set_library(pickle_library)

  semi_persistent_directory = environment.get('SEMI_PERSISTENT_DIRECTORY', None)
  runner_capabilities = frozenset(
      environment.get('RUNNER_CAPABILITIES', '').split())

  _LOGGER.info('semi_persistent_directory: %s', semi_persistent_directory)
  _worker_id = environment.get('WORKER_ID', None)

  if pickle_library != pickler.USE_CLOUDPICKLE:
    try:
      _load_main_session(semi_persistent_directory)
    except LoadMainSessionException:
      exception_details = traceback.format_exc()
      _LOGGER.error(
          'Could not load main session: %s', exception_details, exc_info=True)
      raise
    except Exception:  # pylint: disable=broad-except
      summary = (
          "Could not load main session. Inspect which external dependencies "
          "are used in the main module of your pipeline. Verify that "
          "corresponding packages are installed in the pipeline runtime "
          "environment and their installed versions match the versions used in "
          "pipeline submission environment. For more information, see: https://"
          "beam.apache.org/documentation/sdks/python-pipeline-dependencies/")
      _LOGGER.error(summary, exc_info=True)
      exception_details = traceback.format_exc()
      deferred_exception = LoadMainSessionException(
          f"{summary} {exception_details}")

  _LOGGER.info(
      'Pipeline_options: %s',
      sdk_pipeline_options.get_all_options(drop_default=True))
  control_service_descriptor = endpoints_pb2.ApiServiceDescriptor()
  status_service_descriptor = endpoints_pb2.ApiServiceDescriptor()
  text_format.Merge(
      environment['CONTROL_API_SERVICE_DESCRIPTOR'], control_service_descriptor)
  if 'STATUS_API_SERVICE_DESCRIPTOR' in environment:
    text_format.Merge(
        environment['STATUS_API_SERVICE_DESCRIPTOR'], status_service_descriptor)
  # TODO(robertwb): Support authentication.
  assert not control_service_descriptor.HasField('authentication')

  experiments = sdk_pipeline_options.view_as(DebugOptions).experiments or []
  enable_heap_dump = 'enable_heap_dump' in experiments

  beam_plugins = sdk_pipeline_options.view_as(SetupOptions).beam_plugins or []
  _import_beam_plugins(beam_plugins)

  if dry_run:
    return

  data_sampler = DataSampler.create(sdk_pipeline_options)

  sdk_harness = SdkHarness(
      control_address=control_service_descriptor.url,
      status_address=status_service_descriptor.url,
      worker_id=_worker_id,
      state_cache_size=_get_state_cache_size_bytes(
          options=sdk_pipeline_options),
      data_buffer_time_limit_ms=_get_data_buffer_time_limit_ms(experiments),
      profiler_factory=profiler.Profile.factory_from_options(
          sdk_pipeline_options.view_as(ProfilingOptions)),
      enable_heap_dump=enable_heap_dump,
      data_sampler=data_sampler,
      deferred_exception=deferred_exception,
      runner_capabilities=runner_capabilities,
      element_processing_timeout_minutes=sdk_pipeline_options.view_as(
          WorkerOptions).element_processing_timeout_minutes)
  return fn_log_handler, sdk_harness, sdk_pipeline_options


def _start_profiler(gcp_profiler_service_name, gcp_profiler_service_version):
  try:
    import googlecloudprofiler
    if gcp_profiler_service_name and gcp_profiler_service_version:
      googlecloudprofiler.start(
          service=gcp_profiler_service_name,
          service_version=gcp_profiler_service_version,
          verbose=1)
      _LOGGER.info('Turning on Google Cloud Profiler.')
    else:
      raise RuntimeError('Unable to find the job id or job name from envvar.')
  except Exception as e:  # pylint: disable=broad-except
    _LOGGER.warning(
        'Unable to start google cloud profiler due to error: %s. For how to '
        'enable Cloud Profiler with Dataflow see '
        'https://cloud.google.com/dataflow/docs/guides/profiling-a-pipeline.'
        'For troubleshooting tips with Cloud Profiler see '
        'https://cloud.google.com/profiler/docs/troubleshooting.' % e)


def _get_gcp_profiler_name_if_enabled(sdk_pipeline_options):
  gcp_profiler_service_name = sdk_pipeline_options.view_as(
      GoogleCloudOptions).get_cloud_profiler_service_name()

  return gcp_profiler_service_name


def main(unused_argv):
  """Main entry point for SDK Fn Harness."""
  (fn_log_handler, sdk_harness,
   sdk_pipeline_options) = create_harness(os.environ)
  global _FN_LOG_HANDLER
  if fn_log_handler:
    _FN_LOG_HANDLER = fn_log_handler
  gcp_profiler_name = _get_gcp_profiler_name_if_enabled(sdk_pipeline_options)
  if gcp_profiler_name:
    _start_profiler(gcp_profiler_name, os.environ["JOB_ID"])

  try:
    _LOGGER.info('Python sdk harness starting.')
    sdk_harness.run()
    _LOGGER.info('Python sdk harness exiting.')
  except:  # pylint: disable=broad-except
    _LOGGER.critical('Python sdk harness failed: ', exc_info=True)
    raise
  finally:
    if fn_log_handler:
      fn_log_handler.close()


def terminate_sdk_harness():
  """Flushes the FnApiLogRecordHandler if it exists."""
  _LOGGER.error('The SDK harness will be terminated in 5 seconds.')
  time.sleep(5)
  if _FN_LOG_HANDLER:
    _FN_LOG_HANDLER.close()
  os.kill(os.getpid(), signal.SIGINT)
  # Delay further control flow in the caller until process is terminated.
  time.sleep(60)
  # Try to force-terminate if still running.
  os.kill(os.getpid(), signal.SIGKILL)


def _load_pipeline_options(options_json):
  if options_json is None:
    return {}
  options = json.loads(options_json)
  # Check the options field first for backward compatibility.
  if 'options' in options:
    return options.get('options')
  else:
    # Remove extra urn part from the key.
    portable_option_regex = r'^beam:option:(?P<key>.*):v1$'
    return {
        re.match(portable_option_regex, k).group('key') if re.match(
            portable_option_regex, k) else k: v
        for k, v in options.items()
    }


def _parse_pipeline_options(options_json):
  return PipelineOptions.from_dictionary(_load_pipeline_options(options_json))


def _get_state_cache_size_bytes(options):
  """Return the maximum size of state cache in bytes.

  Returns:
    an int indicating the maximum number of bytes to cache.
  """
  max_cache_memory_usage_mb = options.view_as(
      WorkerOptions).max_cache_memory_usage_mb
  # to maintain backward compatibility
  experiments = options.view_as(DebugOptions).experiments or []
  for experiment in experiments:
    # There should only be 1 match so returning from the loop
    if re.match(r'state_cache_size=', experiment):
      _LOGGER.warning(
          '--experiments=state_cache_size=X is deprecated and will be removed '
          'in future releases.'
          'Please use --max_cache_memory_usage_mb=X to set the cache size for '
          'user state API and side inputs.')
      return int(
          re.match(r'state_cache_size=(?P<state_cache_size>.*)',
                   experiment).group('state_cache_size')) << 20
  return max_cache_memory_usage_mb << 20


def _get_data_buffer_time_limit_ms(experiments):
  """Defines the time limt of the outbound data buffering.

  Note: data_buffer_time_limit_ms is an experimental flag and might
  not be available in future releases.

  Returns:
    an int indicating the time limit in milliseconds of the outbound
      data buffering. Default is 0 (disabled)
  """

  for experiment in experiments:
    # There should only be 1 match so returning from the loop
    if re.match(r'data_buffer_time_limit_ms=', experiment):
      return int(
          re.match(
              r'data_buffer_time_limit_ms=(?P<data_buffer_time_limit_ms>.*)',
              experiment).group('data_buffer_time_limit_ms'))
  return 0


def _get_log_level_from_options_dict(options_dict: dict) -> int:
  """Get log level from options dict's entry `default_sdk_harness_log_level`.
  If not specified, default log level is logging.INFO.
  """
  dict_level = options_dict.get('default_sdk_harness_log_level', 'INFO')
  log_level = dict_level
  if log_level.isdecimal():
    log_level = int(log_level)
  else:
    # labeled log level
    log_level = getattr(logging, log_level, None)
    if not isinstance(log_level, int):
      # unknown log level.
      _LOGGER.error("Unknown log level %s. Use default value INFO.", dict_level)
      log_level = logging.INFO

  return log_level


def _set_log_level_overrides(options_dict: dict) -> None:
  """Set module log level overrides from options dict's entry
  `sdk_harness_log_level_overrides`.
  """
  parsed_overrides = options_dict.get('sdk_harness_log_level_overrides', None)

  if not isinstance(parsed_overrides, dict):
    if parsed_overrides is not None:
      _LOGGER.error(
          "Unable to parse sdk_harness_log_level_overrides: %s",
          parsed_overrides)
    return

  for module_name, log_level in parsed_overrides.items():
    try:
      logging.getLogger(module_name).setLevel(log_level)
    except Exception as e:
      # Never crash the worker when exception occurs during log level setting
      # but logging the error.
      _LOGGER.error(
          "Error occurred when setting log level for %s: %s", module_name, e)


class LoadMainSessionException(Exception):
  """
  Used to crash this worker if a main session file failed to load.
  """
  pass


def _load_main_session(semi_persistent_directory):
  """Loads a pickled main session from the path specified."""
  if semi_persistent_directory:
    session_file = os.path.join(
        semi_persistent_directory, 'staged', names.PICKLED_MAIN_SESSION_FILE)
    if os.path.isfile(session_file):
      # If the expected session file is present but empty, it's likely that
      # the user code run by this worker will likely crash at runtime.
      # This can happen if the worker fails to download the main session.
      # Raise a fatal error and crash this worker, forcing a restart.
      if os.path.getsize(session_file) == 0:
        # Potenitally transient error, unclear if still happening.
        raise LoadMainSessionException(
            'Session file found, but empty: %s. Functions defined in __main__ '
            '(interactive session) will almost certainly fail.' %
            (session_file, ))
      pickler.load_session(session_file)
    else:
      _LOGGER.warning(
          'No session file found: %s. Functions defined in __main__ '
          '(interactive session) may fail.',
          session_file)
  else:
    _LOGGER.warning(
        'No semi_persistent_directory found: Functions defined in __main__ '
        '(interactive session) may fail.')


if __name__ == '__main__':
  main(sys.argv)
